here::i_am("R/RobustEstimation.R")
library(here)
library(readr)
library(ggplot2)
library(tidyr)
library(dplyr)
library(scales)
library(purrr)
library(robustbase) # For MM-estimators
library(MASS) # For rlm() - M-estimators
library(quantreg) # For quantile regression
library(broom) # For tidy output
library(patchwork)
library(jsonlite)
library(ggrepel) # For text repelling in plots

# Source the diagnostic plotting helpers (presumably where
# create_diagnostic_plots(), used further down, is defined -- confirm)
source(here("R", "DiagnosticPlot.R"))

# Load data-----------------------
# read_csv() without printing the guessed column specification
macro_data_path <- here("data", "tidy", "macro_data.csv")
final_data <- read_csv(macro_data_path, show_col_types = FALSE)

# 1. Baseline OLS model--------------
# Dependency share regressed on trade openness, log income per capita,
# log population and natural-resource rents.
baseline_lm_model <- lm(
  formula = avg_dependency_pct ~ trade_pct_gdp + log(gdp_per_capita_ppp) +
    log(population) + total_natural_resources_rents_pct_gdp,
  data = final_data
)

# 2. MM-ESTIMATOR ----------------
## 2.1. MM Estimation--------------
# lmrob() obtains its initial S-estimate by random subsampling and, with the
# default lmrob.control(seed = NULL), draws from the current RNG state. Fix the
# seed so the fit -- and therefore the set of downweighted countries that the
# exclusion model and the JSON export are built from -- is reproducible.
set.seed(2024)
mm_model <- lmrob(
  avg_dependency_pct ~ trade_pct_gdp + log(gdp_per_capita_ppp) +
    log(population) + total_natural_resources_rents_pct_gdp,
  data = final_data, method = "MM"
)
summary(mm_model)

## 2.2. Create custom diagnostic plots for MM estimation------------
# Extract diagnostic data
mm_fitted <- fitted(mm_model)
mm_residuals <- residuals(mm_model)
# Robustness weights from the MM fit; small values mark observations the
# estimator downweighted
mm_weights <- mm_model$rweights
# Standardize by the MM fit's robust residual scale. The classical SD
# (sqrt(var(residuals))) is inflated by exactly the outliers the MM-estimator
# downweights, which shrinks every standardized residual and masks those
# outliers in the Q-Q and scale-location plots.
mm_standardized_residuals <- mm_residuals / mm_model$scale

# Create diagnostic data frame: one row per observation used in the MM fit.
# Countries are aligned with the fit by keeping complete cases of the model
# variables. NOTE(review): this assumes the fit drops exactly those rows
# (e.g. no non-positive values turned into NaN by log()) -- confirm.
diagnostic_data <- data.frame(
  fitted = mm_fitted,
  residuals = mm_residuals,
  std_residuals = mm_standardized_residuals,
  weights = mm_weights,
  country = final_data$country[complete.cases(
    final_data[c(
      "avg_dependency_pct", "trade_pct_gdp",
      "gdp_per_capita_ppp", "population",
      "total_natural_resources_rents_pct_gdp"
    )]
  )],
  obs_number = seq_along(mm_fitted)
)

### 2.2.1. Manual diagnostic plots for MM (since they have different structure)
# Shared look for the MM diagnostic panels: minimal base theme, bold titles,
# grey subtitles, no minor grid lines, legend hidden unless a panel re-enables it.
theme_diagnostic <- theme_minimal() +
  theme(
    plot.title = element_text(size = 12, face = "bold"),
    plot.subtitle = element_text(size = 10, color = "grey60"),
    axis.title = element_text(size = 10),
    axis.text = element_text(size = 9),
    strip.text = element_text(size = 10, face = "bold"),
    panel.grid.minor = element_blank(),
    legend.position = "none"
  )

### Residuals vs Fitted for MM
# Points, a dashed loess trend with its band, and a zero reference line.
p1 <- ggplot(diagnostic_data, aes(fitted, residuals)) +
  geom_point(color = "#1F78B4", size = 2, alpha = 0.7) +
  geom_smooth(
    method = "loess", se = TRUE,
    linetype = "dashed", color = "red", alpha = 0.3
  ) +
  geom_hline(
    yintercept = 0,
    color = "grey50", linetype = "solid", alpha = 0.7
  ) +
  labs(
    title = "Residuals vs Fitted (MM-Estimator)",
    subtitle = "Check for linearity and homoscedasticity",
    x = "Fitted Values",
    y = "Residuals"
  ) +
  theme_diagnostic

### Normal Q-Q plot for MM
# Sample quantiles of the standardized residuals against normal quantiles.
p2 <- ggplot(diagnostic_data, aes(sample = std_residuals)) +
  stat_qq(color = "#1F78B4", size = 2, alpha = 0.7) +
  stat_qq_line(linetype = "dashed", color = "red") +
  labs(
    title = "Normal Q-Q (MM-Estimator)",
    subtitle = "Check for normality of residuals",
    x = "Theoretical Quantiles",
    y = "Standardized Residuals"
  ) +
  theme_diagnostic

### Scale-Location plot for MM
# sqrt(|standardized residual|) against fitted values, with a loess trend.
p3 <- ggplot(diagnostic_data, aes(fitted, sqrt(abs(std_residuals)))) +
  geom_point(color = "#1F78B4", size = 2, alpha = 0.7) +
  geom_smooth(
    method = "loess", se = TRUE,
    linetype = "dashed", color = "red", alpha = 0.3
  ) +
  labs(
    title = "Scale-Location (MM-Estimator)",
    subtitle = "Check for homoscedasticity",
    x = "Fitted Values",
    y = expression(sqrt("|Standardized Residuals|"))
  ) +
  theme_diagnostic

### Robustness Weights Plot (unique to MM estimation)
# Observations whose MM robustness weight falls below this threshold are
# treated as outliers.
low_weight_threshold <- 0.5
# which() skips NA weights instead of returning NA country names
low_weight_countries <- diagnostic_data$country[
  which(diagnostic_data$weights < low_weight_threshold)]

# Encode the status as a factor with both levels always defined. Mapping the
# colour to a logical with an unnamed two-element `labels` vector breaks
# (breaks/labels length mismatch) as soon as only one of TRUE/FALSE occurs,
# i.e. when no observation -- or every observation -- is downweighted.
weight_plot_data <- diagnostic_data
weight_plot_data$status <- factor(
  ifelse(
    weight_plot_data$weights < low_weight_threshold, "Downweighted", "Normal"
  ),
  levels = c("Normal", "Downweighted")
)

p4 <- ggplot(weight_plot_data, aes(x = obs_number, y = weights)) +
  geom_point(aes(color = status), size = 2, alpha = 0.7) +
  geom_hline(
    yintercept = low_weight_threshold, color = "red", linetype = "dashed"
  ) +
  ggrepel::geom_text_repel(
    data = weight_plot_data[which(weight_plot_data$status == "Downweighted"), ],
    aes(label = country),
    size = 3, color = "red", alpha = 0.8
  ) +
  # Fixed limits keep both legend entries (and their colours) stable
  scale_color_manual(
    values = c("Normal" = "#1F78B4", "Downweighted" = "red"),
    limits = c("Normal", "Downweighted"),
    drop = FALSE,
    name = "Status"
  ) +
  labs(
    title = "Robustness Weights (MM-Estimator)",
    subtitle = "Outlier observations downweighted",
    x = "Observation Number",
    y = "Robustness Weight"
  ) +
  theme_diagnostic +
  theme(legend.position = "bottom")

### Combined MM diagnostic plot
# 2 x 2 patchwork grid; `&` applies the bottom legend to every panel and the
# collected guide.
mm_combined_plot <- (p1 + p2) / (p3 + p4) +
  plot_layout(guides = "collect") &
  theme(legend.position = "bottom")

# theme_minimal() leaves the plot background blank, so without an explicit
# `bg` the PNG comes out transparent; use white like the other figures saved
# by this script.
ggsave(
  filename = here("output/RobustEstimationDiagnostics.png"),
  dpi = 300, plot = mm_combined_plot, height = 6, width = 10, bg = "white"
)

# Print outlier identification; the threshold shown follows
# `low_weight_threshold` instead of being hard-coded in the message.
print(paste(
  paste0(
    "Outliers identified by MM-estimator (weight < ",
    low_weight_threshold, "):"
  ),
  paste(low_weight_countries, collapse = ", ")
))

# 3. OLS with explicit exclusion of outliers--------
# Robustness check: drop every country the MM-estimator downweighted and refit
# the baseline OLS specification on what remains.
outlier_countries <- low_weight_countries
cat(
  "Excluding countries for robustness check:",
  toString(outlier_countries),
  "\n"
)
final_data_no_outliers <- final_data %>%
  filter(!(country %in% outlier_countries))

# Raw row counts before/after the exclusion (rows with missing model
# variables are still counted here; lm() drops them with the default
# na.action)
cat(sprintf("Original sample size: %d \n", nrow(final_data)))
cat(sprintf(
  "Sample size after excluding outliers: %d \n",
  nrow(final_data_no_outliers)
))

# Run OLS on cleaned data (same specification as the baseline model)
exclusion_model <- lm(
  formula = avg_dependency_pct ~ trade_pct_gdp + log(gdp_per_capita_ppp) +
    log(population) + total_natural_resources_rents_pct_gdp,
  data = final_data_no_outliers
)
summary(exclusion_model)

## 3.1. Use standardized diagnostic function for exclusion model--------
# Standard OLS diagnostic panels (p1..p4) for the model refitted without the
# MM-flagged outliers.
exclusion_diagnostic_plots <- create_diagnostic_plots(
  exclusion_model,
  final_data_no_outliers
)

# Arrange the panels 2 x 2 with a shared legend and an overall title
exclusion_combined_plot <-
  (exclusion_diagnostic_plots$p1 + exclusion_diagnostic_plots$p2) /
  (exclusion_diagnostic_plots$p3 + exclusion_diagnostic_plots$p4) +
  plot_layout(guides = "collect") +
  plot_annotation(
    title = "Diagnostic Plots: OLS with Outliers Excluded",
    subtitle = "Standard regression diagnostics after removing MM-identified outliers"
  )

ggsave(
  filename = here("output/exclusion_model_diagnostics.png"),
  plot = exclusion_combined_plot,
  bg = "white", dpi = 300, height = 8, width = 12
)

cat("✅ Exclusion model diagnostic plots saved to output/exclusion_model_diagnostics.png\n")

## 3.2. Standard errors comparison---------
# Standard errors = square roots of the diagonal of each model's vcov()
se_ols <- vcov(baseline_lm_model) %>% diag() %>% sqrt()
se_mm <- vcov(mm_model) %>% diag() %>% sqrt()
se_exclusion <- vcov(exclusion_model) %>% diag() %>% sqrt()

# One row per coefficient: estimates, standard errors and pairwise absolute
# coefficient gaps across the three fits (row names = coefficient names)
comparison_df <- local({
  beta_ols <- coef(baseline_lm_model)
  beta_mm <- coef(mm_model)
  beta_excl <- coef(exclusion_model)

  data.frame(
    Variable = names(beta_ols),
    OLS_Coef = beta_ols,
    MM_Coef = beta_mm,
    Exclusion_Coef = beta_excl,
    OLS_SE = se_ols,
    MM_SE = se_mm,
    Exclusion_SE = se_exclusion,
    Diff_OLS_MM = abs(beta_ols - beta_mm),
    Diff_OLS_Exclusion = abs(beta_ols - beta_excl),
    Diff_MM_Exclusion = abs(beta_mm - beta_excl)
  )
})

# 4. Create diagnostic plots for baseline OLS model using standardized function--------
# Same four standard OLS panels, now for the full-sample baseline model.
baseline_diagnostic_plots <- create_diagnostic_plots(
  baseline_lm_model,
  final_data
)

# Arrange the panels 2 x 2 with a shared legend and an overall title
baseline_combined_plot <-
  (baseline_diagnostic_plots$p1 + baseline_diagnostic_plots$p2) /
  (baseline_diagnostic_plots$p3 + baseline_diagnostic_plots$p4) +
  plot_layout(guides = "collect") +
  plot_annotation(
    title = "Diagnostic Plots: Baseline OLS Model",
    subtitle = "Standard regression diagnostics for the original model"
  )

ggsave(
  filename = here("output/baseline_ols_diagnostics.png"),
  plot = baseline_combined_plot,
  bg = "white", dpi = 300, height = 8, width = 12
)

cat("✅ Baseline OLS diagnostic plots saved to output/baseline_ols_diagnostics.png\n")

# 5. Outlier analysis for the QMD integration--------
# Raw input values for the countries flagged as outliers
outlier_data <- final_data[
  final_data$country %in% outlier_countries,
  c(
    "country", "avg_dependency_pct", "trade_pct_gdp",
    "gdp_per_capita_ppp", "population", "total_natural_resources_rents_pct_gdp"
  )
]

# Locate the outliers among the rows actually used in estimation.
# rstandard() and the MM weights have one entry per model row, and
# `diagnostic_data` is aligned with those rows. Positions in `final_data`
# (used previously) point at the wrong observations as soon as any row is
# dropped for missing values. The baseline OLS uses the same formula and data
# as the MM fit, so (with the default na.action) both drop the same rows.
outlier_indices <- which(diagnostic_data$country %in% outlier_countries)
ols_residuals <- rstandard(baseline_lm_model)

# Create outlier analysis data frame: OLS standardized residual and MM weight
# for each flagged observation
outlier_analysis <- data.frame(
  country = diagnostic_data$country[outlier_indices],
  standardized_residual = ols_residuals[outlier_indices],
  mm_weight = mm_weights[outlier_indices],
  stringsAsFactors = FALSE
)

# Calculate coefficient differences as percentage of OLS coefficients:
# |MM - exclusion| gap divided by |OLS coefficient|, times 100
coef_diff_percent <- local({
  mm_vs_exclusion <- abs(coef(mm_model) - coef(exclusion_model))
  ols_magnitude <- abs(coef(baseline_lm_model))
  mm_vs_exclusion / ols_magnitude * 100
})

# Create coefficient differences data frame (one row per coefficient)
coef_diff_analysis <- local({
  mm_vs_exclusion <- abs(coef(mm_model) - coef(exclusion_model))
  data.frame(
    variable = names(coef(baseline_lm_model)),
    absolute_difference = mm_vs_exclusion,
    percent_difference = coef_diff_percent,
    stringsAsFactors = FALSE
  )
})

# 6. Export results for QMD integration--------
# Helper function to add significance stars.
# Maps each p-value to "***" (< 0.001), "**" (< 0.01), "*" (< 0.05),
# "." (< 0.1) or "" otherwise. NA p-values give NA, and names of `p_values`
# are carried over to the result.
add_stars <- function(p_values) {
  # Ordered from loosest to strictest so stricter cutoffs overwrite labels
  cutoffs <- c(0.1, 0.05, 0.01, 0.001)
  symbols <- c(".", "*", "**", "***")
  Reduce(
    function(labels, k) ifelse(p_values < cutoffs[k], symbols[k], labels),
    seq_along(cutoffs),
    init = ""
  )
}

# Extract MM-estimator results
mm_summary <- summary(mm_model)
# Coefficient table (estimate, SE, t, p, stars) in the layout used for export
mm_coef_table <- local({
  coef_matrix <- mm_summary$coefficients
  table <- data.frame(
    variable = rownames(coef_matrix),
    estimate = coef_matrix[, "Estimate"],
    std_error = coef_matrix[, "Std. Error"],
    t_value = coef_matrix[, "t value"],
    p_value = coef_matrix[, "Pr(>|t|)"],
    stringsAsFactors = FALSE
  )
  table$stars <- add_stars(table$p_value)
  table
})

# Extract exclusion model results (summary() computed once)
exclusion_coef_table <- local({
  coef_matrix <- summary(exclusion_model)$coefficients
  p_vals <- coef_matrix[, "Pr(>|t|)"]
  data.frame(
    variable = rownames(coef_matrix),
    estimate = coef_matrix[, "Estimate"],
    std_error = coef_matrix[, "Std. Error"],
    t_value = coef_matrix[, "t value"],
    p_value = p_vals,
    stars = add_stars(p_vals),
    stringsAsFactors = FALSE
  )
})

# Create combined results object
# Bundles the OLS, MM and exclusion fits plus the outlier and coefficient
# comparisons into one nested list, written to JSON below for the QMD report.
combined_regression_results <- local({
  ols_summary <- summary(baseline_lm_model)
  exclusion_summary <- summary(exclusion_model)
  ols_coefs <- ols_summary$coefficients

  # Upper-tail p-value of the overall F-test; `fstat` is the
  # (value, numdf, dendf) vector stored in summary.lm()$fstatistic
  f_test_p_value <- function(fstat) {
    pf(fstat[1], fstat[2], fstat[3], lower.tail = FALSE)
  }

  # max()/min() of an empty vector return -Inf/Inf with a warning; report NA
  # instead when no observation was flagged as an outlier
  extreme_or_na <- function(x, extreme) {
    if (length(x) > 0) extreme(x) else NA_real_
  }

  # Look coefficients up by name rather than by position so the export stays
  # correct if the model formula is reordered
  intercept_row <- match("(Intercept)", coef_diff_analysis$variable)
  trade_row <- match("trade_pct_gdp", coef_diff_analysis$variable)

  # Pairwise absolute coefficient gaps between the three fits
  gap_ols_mm <- abs(coef(baseline_lm_model) - coef(mm_model))
  gap_ols_exclusion <- abs(coef(baseline_lm_model) - coef(exclusion_model))
  gap_mm_exclusion <- abs(coef(mm_model) - coef(exclusion_model))

  # Stability label from the largest OLS-vs-MM coefficient gap
  max_gap_ols_mm <- max(gap_ols_mm)
  stability <- if (is.na(max_gap_ols_mm)) {
    NA_character_
  } else if (max_gap_ols_mm < 0.1) {
    "High"
  } else if (max_gap_ols_mm < 0.2) {
    "Moderate"
  } else {
    "Low"
  }

  list(
    # OLS results (from existing baseline model)
    ols = list(
      coefficients = data.frame(
        variable = rownames(ols_coefs),
        estimate = ols_coefs[, "Estimate"],
        std_error = ols_coefs[, "Std. Error"],
        t_value = ols_coefs[, "t value"],
        p_value = ols_coefs[, "Pr(>|t|)"],
        stars = add_stars(ols_coefs[, "Pr(>|t|)"]),
        stringsAsFactors = FALSE
      ),
      r_squared = ols_summary$r.squared,
      adj_r_squared = ols_summary$adj.r.squared,
      f_statistic = ols_summary$fstatistic[1],
      f_p_value = f_test_p_value(ols_summary$fstatistic),
      n_obs = length(baseline_lm_model$fitted.values),
      method = "OLS"
    ),
    # MM-estimator results
    mm = list(
      coefficients = mm_coef_table,
      # Use summary()'s r.squared when present; otherwise fall back to a
      # median-centred pseudo R-squared on the complete-case rows
      r_squared = if ("r.squared" %in% names(mm_summary)) {
        mm_summary$r.squared
      } else {
        y_fitted <- fitted(mm_model)
        y_actual <- final_data$avg_dependency_pct[complete.cases(final_data[c(
          "avg_dependency_pct", "trade_pct_gdp",
          "gdp_per_capita_ppp", "population",
          "total_natural_resources_rents_pct_gdp"
        )])]
        1 - sum((y_actual - y_fitted)^2) / sum((y_actual - median(y_actual))^2)
      },
      adj_r_squared = NA, # Not typically calculated for robust methods
      f_statistic = NA, # Not applicable for MM-estimator
      f_p_value = NA, # Not applicable for MM-estimator
      n_obs = length(fitted(mm_model)),
      method = "MM-Estimator",
      outliers_identified = low_weight_countries,
      outlier_threshold = low_weight_threshold
    ),
    # Exclusion results
    exclusion = list(
      coefficients = exclusion_coef_table,
      r_squared = exclusion_summary$r.squared,
      adj_r_squared = exclusion_summary$adj.r.squared,
      f_statistic = exclusion_summary$fstatistic[1],
      f_p_value = f_test_p_value(exclusion_summary$fstatistic),
      n_obs = length(exclusion_model$fitted.values),
      method = "OLS (Outliers Excluded)",
      excluded_countries = outlier_countries
    ),
    # Detailed outlier analysis for QMD integration
    outlier_analysis = list(
      outlier_data = outlier_data,
      outlier_statistics = outlier_analysis,
      coefficient_differences = coef_diff_analysis,
      max_standardized_residual = extreme_or_na(
        abs(outlier_analysis$standardized_residual), max
      ),
      min_mm_weight = extreme_or_na(outlier_analysis$mm_weight, min),
      max_percent_difference = max(coef_diff_analysis$percent_difference),
      intercept_absolute_diff =
        coef_diff_analysis$absolute_difference[intercept_row],
      intercept_percent_diff =
        coef_diff_analysis$percent_difference[intercept_row],
      trade_coef_ols = coef(baseline_lm_model)["trade_pct_gdp"],
      trade_coef_mm = coef(mm_model)["trade_pct_gdp"],
      trade_coef_exclusion = coef(exclusion_model)["trade_pct_gdp"],
      trade_percent_diff = coef_diff_analysis$percent_difference[trade_row]
    ),
    # Comparison metrics
    comparison = list(
      coefficient_differences = comparison_df,
      max_difference_ols_mm = max_gap_ols_mm,
      max_difference_ols_exclusion = max(gap_ols_exclusion),
      max_difference_mm_exclusion = max(gap_mm_exclusion),
      mean_difference_ols_mm = mean(gap_ols_mm),
      mean_difference_ols_exclusion = mean(gap_ols_exclusion),
      mean_difference_mm_exclusion = mean(gap_mm_exclusion),
      stability_assessment = stability
    )
  )
})

# Save combined results to JSON for QMD integration
# (scalars unboxed, numbers rounded to 6 digits, pretty-printed)
write_json(
  x = combined_regression_results,
  path = here("output/combined_regression_results.json"),
  auto_unbox = TRUE,
  digits = 6,
  pretty = TRUE
)
cat("✅ Combined regression results saved to output/combined_regression_results.json\n")

# Save the MM-estimator diagnostic plot (12 x 8, white background)
ggsave(
  filename = here("output/mm_estimator_diagnostics.png"),
  plot = mm_combined_plot,
  bg = "white", dpi = 300, height = 8, width = 12
)
cat("✅ MM-estimator diagnostic plots saved to output/mm_estimator_diagnostics.png\n")
